@AlexanderYastrebov
Contributor

Reduce allocations by eliminating the byte reader, hand-rolling the decoding, and reusing message structs.

Comment on lines +121 to +124
func (msg *MessageInitiation) unmarshal(b []byte) error {
	if len(b) < MessageInitiationSize {
		return errMessageTooShort
	}
Contributor Author

Maybe name it reset.

Also, the size of the packet is checked earlier:

case MessageInitiationType:
	if len(packet) != MessageInitiationSize {
		continue
	}
case MessageResponseType:
	if len(packet) != MessageResponseSize {
		continue
	}
case MessageCookieReplyType:
	if len(packet) != MessageCookieReplySize {
		continue
	}

so this check could be removed, along with the error-result check at the call site.

// unmarshal packet

var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
Contributor Author

@AlexanderYastrebov AlexanderYastrebov Jan 2, 2025

Here the byte reader can be replaced with https://pkg.go.dev/encoding/binary#Decode, added in Go 1.23: golang/go#60023 (comment)

Member

If you intend to rework this around binary.Decode instead of this commit, let me know. I'm all for reducing allocations.

Contributor Author

It's just a note; nothing (except zero-copy) can beat copy.

Contributor Author

$ go test ./device/ -c
$ go tool objdump -S -s BenchmarkMessageInitiationUnmarshal.func2 device.test
TEXT golang.zx2c4.com/wireguard/device.BenchmarkMessageInitiationUnmarshal.func2(SB) /home/ayastrebov/src/github.com/WireGuard/wireguard-go/device/noise-protocol_test.go
        b.Run("unmarshal", func(b *testing.B) {
  0x5adb60              493b6610                CMPQ SP, 0x10(R14)
  0x5adb64              0f8638010000            JBE 0x5adca2
  0x5adb6a              55                      PUSHQ BP
  0x5adb6b              4889e5                  MOVQ SP, BP
  0x5adb6e              4883ec30                SUBQ $0x30, SP
  0x5adb72              488b5a08                MOVQ 0x8(DX), BX
  0x5adb76              48895c2428              MOVQ BX, 0x28(SP)
                b.ReportAllocs()
  0x5adb7b              90                      NOPL
        b.Run("unmarshal", func(b *testing.B) {
  0x5adb7c              488b5210                MOVQ 0x10(DX), DX
  0x5adb80              4889542420              MOVQ DX, 0x20(SP)
        b.showAllocResult = true
  0x5adb85              c6800202000001          MOVB $0x1, 0x202(AX)
                for range b.N {
  0x5adb8c              488bb0c0010000          MOVQ 0x1c0(AX), SI
  0x5adb93              eb0b                    JMP 0x5adba0
  0x5adb95              48ffce                  DECQ SI
  0x5adb98              0f1f840000000000        NOPL 0(AX)(AX*1)
  0x5adba0              4885f6                  TESTQ SI, SI
  0x5adba3              0f8ef3000000            JLE 0x5adc9c
                        _ = msgSink.unmarshal(packet)
  0x5adba9              90                      NOPL
        b.Run("unmarshal", func(b *testing.B) {
  0x5adbaa              4881fa94000000          CMPQ DX, $0x94
        if len(b) < MessageInitiationSize {
  0x5adbb1              7ce2                    JL 0x5adb95
                for range b.N {
  0x5adbb3              4889742418              MOVQ SI, 0x18(SP)
        copy(msg.Ephemeral[:], b[8:])
  0x5adbb8              488d7b08                LEAQ 0x8(BX), DI
        return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  0x5adbbc              448b03                  MOVL 0(BX), R8
        msg.Type = binary.LittleEndian.Uint32(b)
  0x5adbbf              4489059aec2700          MOVL R8, golang.zx2c4.com/wireguard/device.msgSink(SB)
        return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  0x5adbc6              448b4304                MOVL 0x4(BX), R8
        msg.Sender = binary.LittleEndian.Uint32(b[4:])
  0x5adbca              44890593ec2700          MOVL R8, golang.zx2c4.com/wireguard/device.msgSink+4(SB)
        copy(msg.Ephemeral[:], b[8:])
  0x5adbd1              488d0590ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+8(SB), AX
  0x5adbd8              4839f8                  CMPQ AX, DI
  0x5adbdb              741c                    JE 0x5adbf9
  0x5adbdd              4889fb                  MOVQ DI, BX
  0x5adbe0              b920000000              MOVL $0x20, CX
  0x5adbe5              e856f8ecff              CALL runtime.memmove(SB)
        b.Run("unmarshal", func(b *testing.B) {
  0x5adbea              488b542420              MOVQ 0x20(SP), DX
        copy(msg.Static[:], b[8+len(msg.Ephemeral):])
  0x5adbef              488b5c2428              MOVQ 0x28(SP), BX
                for range b.N {
  0x5adbf4              488b742418              MOVQ 0x18(SP), SI
        copy(msg.Static[:], b[8+len(msg.Ephemeral):])
  0x5adbf9              488d7b28                LEAQ 0x28(BX), DI
  0x5adbfd              488d0584ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+40(SB), AX
  0x5adc04              4839f8                  CMPQ AX, DI
  0x5adc07              741c                    JE 0x5adc25
  0x5adc09              4889fb                  MOVQ DI, BX
  0x5adc0c              b930000000              MOVL $0x30, CX
  0x5adc11              e82af8ecff              CALL runtime.memmove(SB)
        b.Run("unmarshal", func(b *testing.B) {
  0x5adc16              488b542420              MOVQ 0x20(SP), DX
        copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
  0x5adc1b              488b5c2428              MOVQ 0x28(SP), BX
                for range b.N {
  0x5adc20              488b742418              MOVQ 0x18(SP), SI
        copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
  0x5adc25              488d7b58                LEAQ 0x58(BX), DI
  0x5adc29              488d0588ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+88(SB), AX
  0x5adc30              4839f8                  CMPQ AX, DI
  0x5adc33              741f                    JE 0x5adc54
  0x5adc35              4889fb                  MOVQ DI, BX
  0x5adc38              b91c000000              MOVL $0x1c, CX
  0x5adc3d              0f1f00                  NOPL 0(AX)
  0x5adc40              e8fbf7ecff              CALL runtime.memmove(SB)
        b.Run("unmarshal", func(b *testing.B) {
  0x5adc45              488b542420              MOVQ 0x20(SP), DX
        copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
  0x5adc4a              488b5c2428              MOVQ 0x28(SP), BX
                for range b.N {
  0x5adc4f              488b742418              MOVQ 0x18(SP), SI
        copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
  0x5adc54              488d7b74                LEAQ 0x74(BX), DI
  0x5adc58              4c8d0575ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+116(SB), R8
  0x5adc5f              90                      NOPL
  0x5adc60              4939f8                  CMPQ R8, DI
  0x5adc63              740b                    JE 0x5adc70
  0x5adc65              0f104374                MOVUPS 0x74(BX), X0
  0x5adc69              0f110564ec2700          MOVUPS X0, golang.zx2c4.com/wireguard/device.msgSink+116(SB)
        copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):])
  0x5adc70              488dbb84000000          LEAQ 0x84(BX), DI
  0x5adc77              4c8d0566ec2700          LEAQ golang.zx2c4.com/wireguard/device.msgSink+132(SB), R8
  0x5adc7e              6690                    NOPW
  0x5adc80              4939f8                  CMPQ R8, DI
  0x5adc83              0f840cffffff            JE 0x5adb95
  0x5adc89              0f108384000000          MOVUPS 0x84(BX), X0
  0x5adc90              0f11054dec2700          MOVUPS X0, golang.zx2c4.com/wireguard/device.msgSink+132(SB)
  0x5adc97              e9f9feffff              JMP 0x5adb95
        })
  0x5adc9c              4883c430                ADDQ $0x30, SP
  0x5adca0              5d                      POPQ BP
  0x5adca1              c3                      RET
        b.Run("unmarshal", func(b *testing.B) {
  0x5adca2              4889442408              MOVQ AX, 0x8(SP)
  0x5adca7              e894caecff              CALL runtime.morestack.abi0(SB)
  0x5adcac              488b442408              MOVQ 0x8(SP), AX
  0x5adcb1              e9aafeffff              JMP golang.zx2c4.com/wireguard/device.BenchmarkMessageInitiationUnmarshal.func2(SB)

msgCookieReply MessageCookieReply
msgInitiation MessageInitiation
msgResponse MessageResponse
)
Member

Why move the scope of these?

Contributor Author

I don't remember, reverted.

@zx2c4
Member

zx2c4 commented May 4, 2025

This is interesting to me. Do you have any stats showing the allocations saved?

Reduce allocations by eliminating byte reader, hand-rolled decoding and
reusing message structs.

Signed-off-by: Alexander Yastrebov <yastrebov.alex@gmail.com>
@AlexanderYastrebov AlexanderYastrebov force-pushed the device/reduce-RoutineHandshake-allocs branch from e585270 to 12b6daa Compare May 9, 2025 13:32
@AlexanderYastrebov
Contributor Author

> This is interesting to me. Do you have any stats showing the allocations saved?

I've added a commit with a (synthetic) benchmark.
Of course, it would be beneficial to benchmark RoutineHandshake itself, but that's a bit hard to do. I can think about running the client and server benchmarks in separate processes (similar to, e.g., golang/go#61390).

@AlexanderYastrebov
Contributor Author

AlexanderYastrebov commented May 9, 2025

BTW, it looks like device.log.Verbosef with arguments allocates, which is a bit unfortunate because this happens on error (hence untrusted) paths:

$ go test ./device/ -c && go tool objdump -S -s RoutineHandshake device.test | grep runtime.newobject | wc -l
11
$ # remove arguments to device.log.Verbosef
$ go test ./device/ -c && go tool objdump -S -s RoutineHandshake device.test | grep runtime.newobject | wc -l
6

One idea is to wrap calls like

if device.log.IsVerbose {
  device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
}

but that makes verbose logging, well, verbose 🙈
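A self-contained sketch of the guard idea, assuming a logger whose `Verbosef` is a function-valued field. The `logger` type, its `IsVerbose` flag, and the `discard` sink are hypothetical; the point is only that the guarded call never builds the variadic argument list when verbose logging is off.

```go
package main

import (
	"fmt"
	"testing"
)

// logger is hypothetical: a function-valued Verbosef plus the IsVerbose
// flag floated above.
type logger struct {
	IsVerbose bool
	Verbosef  func(format string, args ...any)
}

// discard stands in for a disabled verbose logger.
func discard(string, ...any) {}

var endpoint = "192.0.2.1:51820"

// unguarded always packs its arguments into a []any before the call.
func unguarded(l *logger) {
	l.Verbosef("Received invalid response message from %s", endpoint)
}

// guarded skips argument construction entirely when IsVerbose is false.
func guarded(l *logger) {
	if l.IsVerbose {
		l.Verbosef("Received invalid response message from %s", endpoint)
	}
}

func main() {
	l := &logger{IsVerbose: false, Verbosef: discard}

	u := testing.AllocsPerRun(1000, func() { unguarded(l) })
	g := testing.AllocsPerRun(1000, func() { guarded(l) })
	fmt.Printf("unguarded: %.0f allocs/op, guarded: %.0f allocs/op\n", u, g)
}
```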

Another option might be to use slog (https://go.dev/blog/slog).
I've sketched a change that uses slog and moves the message variables out of the loop; it seems to bring this down to three allocations, all outside the loop:

Details
diff --git a/device/receive.go b/device/receive.go
index 1392957..4029dde 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -12,6 +12,8 @@ import (
 	"sync"
 	"time"
 
+	"log/slog"
+
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
@@ -270,11 +272,16 @@ func (device *Device) RoutineDecryption(id int) {
  */
 func (device *Device) RoutineHandshake(id int) {
 	defer func() {
-		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
+		slog.Debug("Routine: handshake worker - stopped", "id", id)
 		device.queue.encryption.wg.Done()
 	}()
-	device.log.Verbosef("Routine: handshake worker %d - started", id)
+	slog.Debug("Routine: handshake worker - started", "id", id)
 
+	var (
+		msgCookieReply       MessageCookieReply
+		msgMessageInitiation MessageInitiation
+		msgMessageResponse   MessageResponse
+	)
 	for elem := range device.queue.handshake.c {
 
 		// handle cookie fields and ratelimiting
@@ -285,16 +292,15 @@ func (device *Device) RoutineHandshake(id int) {
 
 			// unmarshal packet
 
-			var reply MessageCookieReply
-			err := reply.unmarshal(elem.packet)
+			err := msgCookieReply.unmarshal(elem.packet)
 			if err != nil {
-				device.log.Verbosef("Failed to decode cookie reply")
+				slog.Debug("Failed to decode cookie reply")
 				goto skip
 			}
 
 			// lookup peer from index
 
-			entry := device.indexTable.Lookup(reply.Receiver)
+			entry := device.indexTable.Lookup(msgCookieReply.Receiver)
 
 			if entry.peer == nil {
 				goto skip
@@ -303,9 +309,9 @@ func (device *Device) RoutineHandshake(id int) {
 			// consume reply
 
 			if peer := entry.peer; peer.isRunning.Load() {
-				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
-				if !peer.cookieGenerator.ConsumeReply(&reply) {
-					device.log.Verbosef("Could not decrypt invalid cookie response")
+				slog.Debug("Receiving cookie response", "endpoint", elem.endpoint.DstToString())
+				if !peer.cookieGenerator.ConsumeReply(&msgCookieReply) {
+					slog.Debug("Could not decrypt invalid cookie response")
 				}
 			}
 
@@ -316,7 +322,7 @@ func (device *Device) RoutineHandshake(id int) {
 			// check mac fields and maybe ratelimit
 
 			if !device.cookieChecker.CheckMAC1(elem.packet) {
-				device.log.Verbosef("Received packet with invalid mac1")
+				slog.Debug("Received packet with invalid mac1")
 				goto skip
 			}
 
@@ -339,7 +345,7 @@ func (device *Device) RoutineHandshake(id int) {
 			}
 
 		default:
-			device.log.Errorf("Invalid packet ended up in the handshake queue")
+			slog.Error("Invalid packet ended up in the handshake queue")
 			goto skip
 		}
 
@@ -350,18 +356,17 @@ func (device *Device) RoutineHandshake(id int) {
 
 			// unmarshal
 
-			var msg MessageInitiation
-			err := msg.unmarshal(elem.packet)
+			err := msgMessageInitiation.unmarshal(elem.packet)
 			if err != nil {
-				device.log.Errorf("Failed to decode initiation message")
+				slog.Error("Failed to decode initiation message")
 				goto skip
 			}
 
 			// consume initiation
 
-			peer := device.ConsumeMessageInitiation(&msg)
+			peer := device.ConsumeMessageInitiation(&msgMessageInitiation)
 			if peer == nil {
-				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
+				slog.Debug("Received invalid initiation message", "endpoint", elem.endpoint.DstToString())
 				goto skip
 			}
 
@@ -373,7 +378,7 @@ func (device *Device) RoutineHandshake(id int) {
 			// update endpoint
 			peer.SetEndpointFromPacket(elem.endpoint)
 
-			device.log.Verbosef("%v - Received handshake initiation", peer)
+			slog.Debug("Received handshake initiation", "peer", peer)
 			peer.rxBytes.Add(uint64(len(elem.packet)))
 
 			peer.SendHandshakeResponse()
@@ -382,25 +387,24 @@ func (device *Device) RoutineHandshake(id int) {
 
 			// unmarshal
 
-			var msg MessageResponse
-			err := msg.unmarshal(elem.packet)
+			err := msgMessageResponse.unmarshal(elem.packet)
 			if err != nil {
-				device.log.Errorf("Failed to decode response message")
+				slog.Error("Failed to decode response message")
 				goto skip
 			}
 
 			// consume response
 
-			peer := device.ConsumeMessageResponse(&msg)
+			peer := device.ConsumeMessageResponse(&msgMessageResponse)
 			if peer == nil {
-				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
+				slog.Debug("Received invalid response message", "endpoint", elem.endpoint.DstToString())
 				goto skip
 			}
 
 			// update endpoint
 			peer.SetEndpointFromPacket(elem.endpoint)
 
-			device.log.Verbosef("%v - Received handshake response", peer)
+			slog.Debug("Received handshake response", "peer", peer)
 			peer.rxBytes.Add(uint64(len(elem.packet)))
 
 			// update timers
@@ -413,7 +417,7 @@ func (device *Device) RoutineHandshake(id int) {
 			err = peer.BeginSymmetricSession()
 
 			if err != nil {
-				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
+				slog.Error("Failed to derive keypair", "peer", peer, "error", err)
 				goto skip
 			}
$ go test ./device/ -c && go tool objdump -S -s RoutineHandshake device.test | grep runtime.newobject | wc -l
3
...
		msgCookieReply       MessageCookieReply
  0x5be6bc		488d05bd3d0800		LEAQ 0x83dbd(IP), AX		
  0x5be6c3		e8f887e5ff		CALL runtime.newobject(SB)	
  0x5be6c8		4889842470020000	MOVQ AX, 0x270(SP)		
		msgMessageInitiation MessageInitiation
  0x5be6d0		488d05e90b0900		LEAQ 0x90be9(IP), AX		
  0x5be6d7		e8e487e5ff		CALL runtime.newobject(SB)	
  0x5be6dc		4889842468020000	MOVQ AX, 0x268(SP)		
		msgMessageResponse   MessageResponse
  0x5be6e4		488d05f50c0900		LEAQ 0x90cf5(IP), AX		
  0x5be6eb		e8d087e5ff		CALL runtime.newobject(SB)	
  0x5be6f0		4889842460020000	MOVQ AX, 0x260(SP)		
	for elem := range device.queue.handshake.c {
  0x5be6f8		488b8c2490020000	MOVQ 0x290(SP), CX	
  0x5be700		488b9130020000		MOVQ 0x230(CX), DX	
...

I can propose another PR to swap device logger for slog logger.

goos: linux
goarch: amd64
pkg: golang.zx2c4.com/wireguard/device
                                         │      -      │
                                         │   sec/op    │
MessageInitiationUnmarshal/binary.Read-8   1.508µ ± 2%
MessageInitiationUnmarshal/unmarshal-8     12.66n ± 2%
geomean                                    138.1n

                                         │      -       │
                                         │     B/op     │
MessageInitiationUnmarshal/binary.Read-8   208.0 ± 0%
MessageInitiationUnmarshal/unmarshal-8     0.000 ± 0%
geomean                                               ¹
¹ summaries must be >0 to compute geomean

                                         │      -       │
                                         │  allocs/op   │
MessageInitiationUnmarshal/binary.Read-8   2.000 ± 0%
MessageInitiationUnmarshal/unmarshal-8     0.000 ± 0%
geomean                                               ¹
¹ summaries must be >0 to compute geomean

Signed-off-by: Alexander Yastrebov <yastrebov.alex@gmail.com>
@AlexanderYastrebov AlexanderYastrebov force-pushed the device/reduce-RoutineHandshake-allocs branch from 12b6daa to 39abcab Compare May 9, 2025 15:49
@zx2c4
Member

zx2c4 commented May 15, 2025

slog fixes this issue because the arguments are evaluated lazily somehow?

@zx2c4
Member

zx2c4 commented May 15, 2025

I merged this. Let's take up the slog conversation on a new PR.

Thanks for the patch!

@zx2c4 zx2c4 closed this May 15, 2025
@AlexanderYastrebov AlexanderYastrebov deleted the device/reduce-RoutineHandshake-allocs branch May 15, 2025 15:27

2 participants